/*
    Copyright 2006-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Definitions of nbody_data_grid methods. \ingroup makegrid

// $Id$

#include "config.h"
#include "sphgrid.h"
#include "constants.h"
#include "boost/thread/thread.hpp"
#include <hpm.h> 
#include <algorithm>
#include <iostream>
#include "config.h"

#ifndef HAVE_COPY_IF
#include "copy_if.h"
#else
using std::copy_if;
#endif

using namespace std;

/** Assignment operator.  Complicated by the reference semantics of
    the Blitz copy constructor: L_lambda must be detached from any
    sharers (and resized if necessary) before its contents can be
    copied safely.  */
mcrx::nbody_data& mcrx::nbody_data::operator= (const nbody_data& rhs)
{
  // Plain members can simply be copied over.
  p_g_ = rhs.p_g_;
  m_g_ = rhs.m_g_;
  L_bol_ = rhs.L_bol_;
  SFR = rhs.SFR;
  m_metals_ = rhs.m_metals_;
  p_s_ = rhs.p_s_;
  m_s_ = rhs.m_s_;
  m_m_s = rhs.m_m_s;
  age_l = rhs.age_l;
  age_m = rhs.age_m;
  gas_teff_m = rhs.gas_teff_m;
  gas_temp_m = rhs.gas_temp_m;

  // L_lambda is a Blitz array whose copies share storage, so first
  // make sure we own our data, then match sizes, then copy elements.
  L_lambda.makeUnique();
  if (rhs.L_lambda.size() != L_lambda.size())
    L_lambda.resize(rhs.L_lambda.size());
  L_lambda = rhs.L_lambda;

  return *this;
}


bool mcrx::tolerance_checker::unify_cell_p (const nbody_data_cell& sum,
					    const nbody_data_cell& sumsq,
					    int n, const T_cell& c) const
{
  const T_float m_g_mean = sum.m_g()/n;
  const T_float m_g_stdev2 = sumsq.m_g()/n- m_g_mean*m_g_mean;
  const T_float m_g_fractional = m_g_stdev2 != 0?
    m_g_stdev2/(m_g_mean*m_g_mean) :0;

  const T_float m_m_mean = sum.m_metals()/n;
  const T_float m_m_stdev2 = sumsq.m_metals()/n- m_m_mean*m_m_mean;
  const T_float m_m_fractional = m_m_stdev2 != 0?
    m_m_stdev2/(m_m_mean*m_m_mean) :0;

  const T_float L_bol_mean = sum.L_bol()/n;
  const T_float L_bol_stdev2 = sumsq.L_bol()/n-L_bol_mean*L_bol_mean;
  const T_float L_bol_fractional = L_bol_stdev2 != 0?
    L_bol_stdev2/(L_bol_mean*L_bol_mean) :0;

  // estimate of column density of metals = rho*l = m/V*l = m*V^-2/3
  const T_float col_est = (sum.m_g()*gas_metallicity_ + sum.m_metals()) *
    pow(c.volume(), -2./3);

  return ((( m_g_fractional < tolerance_.m_g()*tolerance_.m_g()) ||
	   ( m_g_stdev2 < tolerance_absolute_.m_g()*tolerance_absolute_.m_g() ))
	  &&
	  (( m_m_fractional < tolerance_.m_metals()*tolerance_.m_metals()) ||
	   ( m_m_stdev2 < tolerance_absolute_.m_metals()*tolerance_absolute_.m_metals() ))
	  &&
	  (( L_bol_fractional < tolerance_.L_bol()*tolerance_.L_bol()) ||
	    ( L_bol_stdev2 < 
	      tolerance_absolute_.L_bol()*tolerance_absolute_.L_bol()*L_bol_tot_*L_bol_tot_ ))
	  &&
	  ( col_est < max_metal_column_ )
	   );
}


/** Decides whether a cell must be refined because it would otherwise
    exceed the maximum allowed metal column density.  */
bool
mcrx::tolerance_checker::refine_cell_p (const check_particle_overlap_p& cpo,
					const T_cell& c) const
{
  // Column density estimate: the maximum m/r^3 of the overlapping
  // particles times the cell side length, with a conservative
  // factor of 20.
  const double side = pow(c.volume(), 1./3);
  const double col_est = 20*side*cpo.max_rho()*(3./(4*constants::pi));

  return col_est > max_metal_column_;
}

/** Prints the unification/refinement tolerances to stdout.  */
void mcrx::tolerance_checker::print_tolerances() const
{
  cout << "  Fractional unification tolerance: "
       << tolerance_.L_bol() << " L_bol, "
       << tolerance_.m_g() << " gas, "
       << tolerance_.m_metals() << " metals" << "\n"
       << "  Absolute unification tolerance: "
       << tolerance_absolute_.L_bol() << " W, "
       << tolerance_absolute_.m_g() << " M_Sun gas, "
       << tolerance_absolute_.m_metals() << " M_Sun metals" << "\n"
       << "  Resolve metal column density of : "
       << max_metal_column_ << " M_Sun" << "\n\n";
}

/** Recursively refines a grid to a resolution consistent with the
    supplied list of particles.  Because of threading issues, the work
    for each cell is done by the recursive_refinement_body function.

    Returns the accuracy data accumulated over all cells of the grid,
    which the caller uses to decide whether the sub grid can be
    unified again.  Assumes the grid contains at least one cell
    (lbegin() != lend()) -- TODO confirm this always holds.  */
mcrx::refinement_accuracy_data
mcrx::nbody_data_grid::recursive_refine (T_grid& g, 
					 const vector<T_particle*>& pl,
					 int level, vector<int>& stats)
{
  // we must initialize the structure in this way because of blitz
  // array semantics: racc is copy-initialized from the first cell's
  // result rather than default-constructed and then assigned, since
  // assignment to an unsized blitz array would not do what we want
  T_grid::local_iterator c = g.lbegin();
  refinement_accuracy_data racc = recursive_refine_body (c, pl, level, stats);
  ++c;
  // and now loop over the other cells, accumulating their accuracy
  // data into racc
  for (; c != g.lend(); ++c) {
     racc+= recursive_refine_body (c, pl, level, stats);
  }

  return racc;
}


/** Projects the quantities in a list of particles onto a grid cell.
    This computes how much density, luminosity, etc., of the particles
    fall within the specified grid cell, stores the result in the
    cell, and returns the refinement_accuracy_data that seeds the
    accumulation chain.  */
mcrx::refinement_accuracy_data
mcrx::nbody_data_grid::project_particles (T_cell& c,
					  const vector<T_particle*> & pl)
{
  // Sanity check: the cell extent must agree with its stored size.
  assert (all((abs(c.getmax()- c.getmin())/c.getsize()-1) < 1e-5));

  // Accumulator initialized by copy construction from data_zero,
  // which handles sizing if the data contains a blitz array.
  T_particle::T_data accum (data_zero);

  // Add each particle's overlapping fraction to the accumulator.
  typedef vector<T_particle*>::const_iterator particle_iter;
  for (particle_iter it = pl.begin(); it != pl.end(); ++it) {
    const T_float frac = (*it)->project(c);
    assert (frac >= 0);
    assert (frac <= 1);
    accum.add_to ((*it)->data (), frac, c.volume());
  }

  c.set_data(new T_data (accum));

  // This cell contributes (sum, sum of squares, count 1, is a leaf)
  // to the accuracy accumulation.
  return refinement_accuracy_data (*c.data(), *c.data()* *c.data(), 1, true);
}


/** Checks if a sub grid can be unified into one cell.  If the grid
    refinement criteria are fulfilled, the sub grid in c is removed
    and c becomes a leaf cell.  Notice the non-const reference being
    modified: racc.all_leaves records whether unification happened. */
void mcrx::nbody_data_grid::unrefine_if_possible (T_cell& c,
						  refinement_accuracy_data& racc)
{
  assert (!c.is_leaf());
  if (c.is_leaf())
    // defensive guard; a leaf has no sub grid to unify
    return;

  if (!tol_checker_->unify_cell_p (racc.sum, racc.sumsq, racc.n, c)) {
    // Outside tolerance: mark that the sub cells were not unified.
    racc.all_leaves = false;
    return;
  }

  // Within tolerance.  This "constructor" builds the T_data object
  // representing the unification, and the cell collapses onto it.
  T_data*const unified = new T_data (T_data::unification (racc.sum, racc.n));
  c.unrefine(unified);
}


/** Takes the cells from the list and does the refinement. This
    function is called by the thread_start object for each of the
    threads.  */
void 
mcrx::nbody_data_grid::pop_cell_and_refine()
{
  cout << "thread starting\n";
  //hpm::hpmtstart (1, "recursive refine ");
  vector<int> stats;
  bool first = true;
  while (true) {

    // pop another cell off the stack, exit if empty
    T_cell* current_cell;
    int start_level;
    {
      // open scope for locking stack_mutex
      boost::mutex::scoped_lock stack_lock (cell_stack_mutex);
      if (!stats.empty())
	for (int i = 0; i < stats.size(); ++i)
	  creation_stats [i]+= stats [i];
      if (cell_stack.empty())
	break;
      current_cell = cell_stack.back().first;
      start_level  = cell_stack.back().second;
      cell_stack.pop_back();
      if (!first) {
	ctr++;
	// cout << endl;
      }
      first = false;
    }
    stats.assign(max_level, 0 );

    check_particle_overlap_p cpo (*current_cell);
    std::for_each (particle_list_.begin(), particle_list_.end(), cpo);
    // print cells that seem to still be high-workload
    // the estimated number of ADDITIONAL refinements needed
    int estimated = int(floor(log (dot (current_cell->getsize(), 
					current_cell->getsize()) /
				   (size_factor*size_factor*
				    cpo.min_size ()*cpo.min_size ()))/
			      (2*log(2.))+2));
    // correct by checking against max_levels
    estimated = std::min (estimated, max_level - start_level);
    if (estimated > work_chunk_levels)
      cout << '\t' << cpo.min_size()<< '\t' << estimated << '\t'
	   << particle_list_.size()<< endl;
    recursive_refine_body (current_cell, particle_list_, start_level, stats);
  } // while 

  {
    // make final counter increment
    boost::mutex::scoped_lock stack_lock (cell_stack_mutex);
    ctr++;
    cout << endl;
  }
  
  //hpm::hpmtstop (1);
  cout << "thread exiting\n";
  cout.flush();
}


/** Recursively refines a grid cell. The refinement is done if the
    particles are small enough that this is deemed necessary.  It then
    recursively calls recursive_refine and then sees if the cells can
    also be unified again.  If the refinement is not necessary, the
    particles are projected onto the cell.

    \param c      the cell to (possibly) refine
    \param pl     candidate particles; filtered here to those
                  overlapping c before recursing
    \param level  refinement depth of c; assumed >= 1 since
                  stats[level - 1] is incremented below -- TODO confirm
    \param stats  per-level count of cells that became leaves  */
mcrx::refinement_accuracy_data 
mcrx::nbody_data_grid::recursive_refine_body (T_cell* c, 
					      const vector<T_particle*>& pl,
					      int level,
					      vector<int>& stats )
{
  // Put the particles that overlap with the cell c in the vector
  vector<T_particle*> pl_next;
  check_particle_overlap_p cpo (*c);
  copy_if (pl.begin(), pl.end(), back_inserter (pl_next), cpo);

  // Now check if we SHOULD refine the cell. We refine if we are not
  // at max level and if either the size factor or the absolute
  // tolerance indicate we need more refinement.
  // NOTE(review): cpo's min_size()/max_rho() are read after copy_if,
  // which takes the predicate by value -- presumably cpo accumulates
  // through shared/reference state; confirm in check_particle_overlap_p.
  if ((level < max_level) &&
      ((size_factor*size_factor*cpo.min_size()*cpo.min_size() <
	dot (c->getsize(), c->getsize())) || 
       (tol_checker_->refine_cell_p(cpo, *c)) ) ) {
    // yes, refine

    assert (c->is_leaf());// what would happen if this is not the case??
    c->refine();
    
    // and recursively refine those cells
    refinement_accuracy_data racc =
      recursive_refine (*c->sub_grid(), pl_next, level + 1, stats); 
    if (racc.all_leaves) {
      // recursive_refine returned true, which means that that grid
      // contains only leaf cells.  We can try to unrefine them
      unrefine_if_possible (*c, racc);
    }
    return racc;
  }
  else {
    // no, don't refine.  Project the particles onto this cell and
    // count it as a leaf created at this level.
    ++stats [level - 1]; 
    return project_particles (*c, pl_next);
  }
}

/** Pre-refines grid cells which are estimated to be high-workload, to
    improve load balancing.  This function goes through the initial
    cells and subdivides them as they are put in the cell_stack if the
    estimated number of refinements needed is greater than
    work_chunk_levels.  This improves load balancing by ensuring that
    one thread doesn't get stuck with the core of the galaxy, taking
    many times longer than all the other threads.

    Run concurrently by several threads; all access to balance_queue,
    cell_stack and ctr is done with cell_stack_mutex held, while the
    particle filtering and workload estimate are done outside the
    lock.  */
void
mcrx::nbody_data_grid::process_balance_queue_cell ()
{
  cout << "Thread starting\n";
  //hpm::hpmtstart(10,"balancing");

  // get first cell
  balance_queue_data current_cell;
  {
    boost::mutex::scoped_lock stack_lock (cell_stack_mutex);
    if (balance_queue.empty())
      // NOTE(review): this early return skips the "Thread exiting"
      // message printed at the bottom; harmless but asymmetric.
      return;
    current_cell = balance_queue.back();
    balance_queue.pop_back();
    ctr++;
  }
  
  while (true) {
    // now process this cell
    T_cell*const c = current_cell.cell;
    // filter the inherited particle list down to the particles that
    // overlap c; the filtered list is shared (via shared_ptr) by all
    // sub cells pushed below
    boost::shared_ptr<vector<T_particle*> > pl_next(new vector<T_particle*>());
    check_particle_overlap_p cpo (*c);
    copy_if (current_cell.particle_list_->begin(),
	     current_cell.particle_list_->end(),
	     back_inserter (*pl_next), cpo);
    // estimate the number of refinements needed from the ratio of
    // cell size to the smallest overlapping particle size
    int estimated = int(floor(log (dot (c->getsize(),
					 c->getsize()) /
				    (size_factor*size_factor*
				     cpo.min_size ()*cpo.min_size ()))/
			       (2*log(2.))+2)); 
    // never estimate beyond the maximum allowed refinement level
    estimated = std::min (estimated, max_level - current_cell.level);

    if (estimated > work_chunk_levels) {
      // This is a "high workload cell", subdivide it
      assert (c->is_leaf());
      c->refine();
    }

    // and now add the new cells to the appropriate list
    {
      boost::mutex::scoped_lock stack_lock (cell_stack_mutex);
      
      if (estimated > work_chunk_levels) {
	// and add the sub cells to the balance queue, to be examined
	// again (possibly by another thread)
	for (local_iterator cc = c->sub_grid()->lbegin ();
	     cc != c->sub_grid()->lend (); ++cc)
	  balance_queue.push_back(balance_queue_data (cc, pl_next,
						      current_cell.level + 1));
      }
      else {
	// a "low workload cell", put it in the cell_stack
	cell_stack.push_back(std::pair<T_cell*, int> (c, current_cell.level) );
      }
      
      // and get the next cell
      if (balance_queue.empty())
	break;
      current_cell = balance_queue.back();
      balance_queue.pop_back();
      ctr++;
    }
  }
  
  //hpm::hpmtstop (10);
  cout << "Thread exiting\n";
}


/** Creates the "load balanced" list of cells by pre-refining cells
    that are likely to have a high workload.  This is done with
    multiple threads, this function sets up the information for and
    starts the threads that create the balanced work cells.  */
void mcrx::nbody_data_grid::create_balanced_queue ()
{
  // if we are not using threads, this step is pointless and we should
  // proceed immediately to the next step
  if ((n_threads<=1) || (work_chunk_levels <= 0)) {
      for (local_iterator i = lbegin (); i != lend (); ++i)
	cell_stack.push_back(std::pair<T_cell*, int> (i, 1) );
      return;
  }
  
  // put cells in balance_queue 
  cout << "Creating balanced work chunks" << endl;
  for (local_iterator i = lbegin (); i != lend (); ++i)
    balance_queue.push_back
      (balance_queue_data 
       (i,boost::shared_ptr<vector<T_particle*> >
	(new vector<T_particle*> (particle_list_)), 1));

  // spawn threads
  ctr = 0;
  boost::thread_group threads;
  for (int i = 0; i < n_threads; ++i)
    //the thread_start object calls process_balance_queue_cell
    threads.create_thread(balance_thread_start (this));
  threads.join_all();
}
    

/** The top-level function that does the grid building.  If we are
    running with several threads, then the "load balancing" is done.  */
void mcrx::nbody_data_grid::build_grid ()
{
  creation_stats.assign(max_level, 0 );
  cout << "Starting with " << n_cells () << " initial grid cells" << endl;
  create_balanced_queue ();
  
  cout << "Now beginning work on " << n_cells ()
       << " balanced cells" << endl;
  ctr = 0;
  if (n_threads > 1) {
    // spawn threads
    //hpm::hpmtstart (11, "master thread grid build");
    boost::thread_group threads;
    for (int i = 0; i < n_threads; ++i)
      //the thread_start object calls pop_cell_and_refine ()
      threads.create_thread(thread_start (this));
    threads.join_all();
    //hpm::hpmtstop (11);
  }
  else // no threads, just call directly
    pop_cell_and_refine ();
}
